# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import json
import random
import uuid
from typing import Dict, Any, List, Tuple

from azure.cosmos import PartitionKey
from azure.cosmos import CosmosClient
import azure.cosmos.exceptions as exceptions

import config
from azure.identity import DefaultAzureCredential
from azure.cosmos.http_constants import HttpHeaders

# ----------------------------------------------------------------------------------------------------------
# Prerequisites -
#
# 1. An Azure Cosmos account -
#    https://azure.microsoft.com/documentation/articles/documentdb-create-account/
#
# 2. Microsoft Azure Cosmos PyPI package -
#    https://pypi.org/project/azure-cosmos/
# ----------------------------------------------------------------------------------------------------------
# Sample - demonstrates how to manage session tokens. By default, the SDK manages session tokens for you. These samples
# are for use cases where you want to manage session tokens yourself.
#
# 1. Storing session tokens in a cache by feed range from the partition key.
#
# 2. Storing session tokens in a cache by feed range from the container.
#
# ----------------------------------------------------------------------------------------------------------
# Note -
#
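# Running this sample will create (and delete) a database and a container on your account.
# If a database with the configured ID already exists, the sample's cleanup step deletes it.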
# Each time a Container is created the account will be billed for 1 hour of usage based on
# the provisioned throughput (RU/s) of that account.
# ----------------------------------------------------------------------------------------------------------

# Cosmos account endpoint, read from the local ``config`` module's ``settings`` mapping.
HOST = config.settings['host']
# Credential from azure.identity used to authenticate the CosmosClient created in run_sample.
CREDENTIAL = DefaultAzureCredential()
# ID of the database that run_sample creates if it does not exist and later deletes.
DATABASE_ID = config.settings['database_id']
# ID of the container (partition key path '/pk') that run_sample creates inside DATABASE_ID.
CONTAINER_ID = config.settings['container_id']

def storing_session_tokens_pk(container, num_items: int = 5):
    """Simulate caching session tokens keyed by the feed range of each item's partition key.

    Creates ``num_items`` items whose partition key is a random value from 'A1' to 'A10'. For each item,
    the session token cached under ``container.feed_range_from_partition_key(pk)`` is used for the write
    and then refreshed with the token returned by the service.

    :param container: The container proxy the items are created in.
    :param num_items: Number of items to create. Defaults to 5.
    """
    print('1. Storing session tokens in a cache by feed range from the partition key.')

    # The cache is a Dict here for simplicity but in a real-world scenario, it would be some service.
    cache: Dict[str, Any] = {}

    # Everything below is just a simulation of what could be run on different machines and clients
    # to store session tokens in a cache by feed range from the partition key.
    # This list is shared by every call below and accumulates (feed_range, session_token) pairs.
    feed_ranges_and_session_tokens: List[Tuple[Dict[str, Any], str]] = []

    # populating cache with session tokens
    for _ in range(num_items):
        item = {
            'id': 'item' + str(uuid.uuid4()),
            'name': 'sample',
            'pk': 'A' + str(random.randint(1, 10))
        }
        target_feed_range = container.feed_range_from_partition_key(item['pk'])
        perform_create_item_with_cached_session_token(cache, container, feed_ranges_and_session_tokens, item,
                                                      target_feed_range)

def perform_create_item_with_cached_session_token(cache, container, feed_ranges_and_session_tokens, item,
                                                  target_feed_range):
    """Create ``item`` with the session token cached for ``target_feed_range``, then update the cache.

    :param cache: Dict mapping a JSON-serialized feed range to the latest session token stored for it.
        Updated in place.
    :param container: The container proxy the item is created in.
    :param feed_ranges_and_session_tokens: List of ``(feed_range, session_token)`` pairs passed to
        ``container.get_latest_session_token``. Appended to in place; all current cache entries and the
        new response token are added on every call.
    :param item: The item to create.
    :param target_feed_range: The feed range whose cached session token is used and refreshed.
    """
    # Feed ranges are dicts, which cannot be dict keys; serialize to an immutable string key instead.
    target_key = json.dumps(target_feed_range)
    # None when nothing has been cached for this feed range yet.
    session_token = cache.get(target_key)
    response = container.create_item(item, session_token=session_token)
    response_session_token = response.get_response_headers()[HttpHeaders.SessionToken]
    # adding everything from the cache in case consolidation is possible
    for cached_key, cached_session_token in cache.items():
        feed_ranges_and_session_tokens.append((json.loads(cached_key), cached_session_token))
    feed_ranges_and_session_tokens.append((target_feed_range, response_session_token))
    latest_session_token = container.get_latest_session_token(feed_ranges_and_session_tokens, target_feed_range)
    # Store under the target feed range's own key, not a key left over from the loop above.
    cache[target_key] = latest_session_token

def storing_session_tokens_container_feed_ranges(container, num_items: int = 5):
    """Simulate caching session tokens keyed by the container's own feed ranges.

    The container's feed ranges are read once. Each of the ``num_items`` created items (random
    partition key 'A1' to 'A10') is mapped to the first of those feed ranges for which
    ``container.is_feed_range_subset(feed_range, feed_range_from_pk)`` is true. The session token
    cached under that feed range is used for the write and then refreshed.

    :param container: The container proxy the items are created in.
    :param num_items: Number of items to create. Defaults to 5.
    """
    print('2. Storing session tokens in a cache by feed range from the container.')

    # The cache is a dictionary here for simplicity but in a real-world scenario, it would be some service.
    cache: Dict[str, Any] = {}

    # Everything below is just a simulation of what could be run on different machines and clients
    # to store session tokens in a cache by feed range from the container.
    # This list is shared by every call below and accumulates (feed_range, session_token) pairs.
    feed_ranges_and_session_tokens: List[Tuple[Dict[str, Any], str]] = []
    feed_ranges = list(container.read_feed_ranges())

    # populating cache with session tokens
    for _ in range(num_items):
        item = {
            'id': 'item' + str(uuid.uuid4()),
            'name': 'sample',
            'pk': 'A' + str(random.randint(1, 10))
        }
        feed_range_from_pk = container.feed_range_from_partition_key(item['pk'])
        # If no container feed range matches, fall back to the partition key's own feed range
        # instead of caching under an empty dict.
        target_feed_range: Dict[str, Any] = next(
            (feed_range for feed_range in feed_ranges
             if container.is_feed_range_subset(feed_range, feed_range_from_pk)),
            feed_range_from_pk,
        )
        perform_create_item_with_cached_session_token(cache, container, feed_ranges_and_session_tokens, item,
                                                      target_feed_range)

def run_sample():
    """Run both session-token samples against the configured database and container.

    The database is created if missing and deleted in a ``finally`` clause, so it is cleaned up even
    when a sample step raises. ``CosmosHttpResponseError`` failures (including during cleanup, except
    "not found") are printed rather than re-raised; other exceptions propagate after cleanup.
    """
    with CosmosClient(HOST, CREDENTIAL) as client:
        # Stays None if database creation fails, so cleanup knows there is nothing to delete.
        db = None
        try:
            db = client.create_database_if_not_exists(id=DATABASE_ID)
            container = db.create_container_if_not_exists(id=CONTAINER_ID, partition_key=PartitionKey('/pk'))

            # example of storing session tokens in cache by feed range from the partition key
            storing_session_tokens_pk(container)

            # example of storing session tokens in cache by feed range from the container
            storing_session_tokens_container_feed_ranges(container)

        except exceptions.CosmosHttpResponseError as e:
            print('\nrun_sample has caught an error. {0}'.format(e.message))

        finally:
            # cleanup database after sample, also when a sample step failed, so it is not left behind
            if db is not None:
                try:
                    client.delete_database(db)
                except exceptions.CosmosResourceNotFoundError:
                    pass
                except exceptions.CosmosHttpResponseError as e:
                    print('\nrun_sample has caught an error. {0}'.format(e.message))
            print("\nrun_sample done")


# Run the sample only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    run_sample()
